!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routines for calculating a complex matrix exponential.
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

MODULE matrix_exp

   USE cp_cfm_basic_linalg,             ONLY: cp_cfm_scale_and_add,&
                                              cp_cfm_solve
   USE cp_cfm_types,                    ONLY: cp_cfm_create,&
                                              cp_cfm_release,&
                                              cp_cfm_set_all,&
                                              cp_cfm_to_cfm,&
                                              cp_cfm_type
   USE cp_fm_basic_linalg,              ONLY: cp_complex_fm_gemm,&
                                              cp_fm_scale_and_add,&
                                              cp_fm_solve
   USE cp_fm_struct,                    ONLY: cp_fm_struct_double,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_set_all,&
                                              cp_fm_to_fm,&
                                              cp_fm_type
   USE cp_log_handling,                 ONLY: cp_to_string
   USE kinds,                           ONLY: dp
   USE mathconstants,                   ONLY: fac,&
                                              z_one,&
                                              z_zero
   USE message_passing,                 ONLY: mp_comm_type,&
                                              mp_para_env_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'matrix_exp'

   PUBLIC :: taylor_only_imaginary, &
             taylor_full_complex, &
             exp_pade_full_complex, &
             exp_pade_only_imaginary, &
             get_nsquare_norder, &
             arnoldi, exp_pade_real

CONTAINS

! **************************************************************************************************
!> \brief specialized subroutine for purely imaginary matrix exponentials
!>        Evaluates exp(i*X), X = im_matrix/2**nsquare, with a Taylor series split into its
!>        real (even powers) and imaginary (odd powers) parts, then squares the complex
!>        result nsquare times.
!> \param exp_H     on output exp_H(1) holds the real part, exp_H(2) the imaginary part
!> \param im_matrix real matrix A of the exponent i*A (not modified)
!> \param nsquare   number of squaring steps
!> \param ntaylor   Taylor order; terms are added in (odd, even) pairs, so the order is rounded up
!>                  to an even number
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE taylor_only_imaginary(exp_H, im_matrix, nsquare, ntaylor)
      TYPE(cp_fm_type), DIMENSION(2)                     :: exp_H
      TYPE(cp_fm_type), INTENT(IN)                       :: im_matrix
      INTEGER, INTENT(in)                                :: nsquare, ntaylor

      CHARACTER(len=*), PARAMETER :: routineN = 'taylor_only_imaginary'

      INTEGER                                            :: handle, k, ndim, npairs
      REAL(KIND=dp)                                      :: coef, denom, scal
      TYPE(cp_fm_type)                                   :: even_pow, odd_pow, res_im, res_re

      CALL timeset(routineN, handle)

      CALL cp_fm_get_info(im_matrix, nrow_global=ndim)
      ! the scaling 2**(-nsquare) is applied inside each gemm instead of to im_matrix itself
      scal = 1.0_dp/(2.0_dp**REAL(nsquare, dp))

      CALL cp_fm_create(even_pow, matrix_struct=im_matrix%matrix_struct, name="T1")
      CALL cp_fm_create(odd_pow, matrix_struct=im_matrix%matrix_struct, name="T2")
      CALL cp_fm_create(res_im, matrix_struct=im_matrix%matrix_struct, name="T3")
      CALL cp_fm_create(res_re, matrix_struct=im_matrix%matrix_struct, name="Tres")

      ! zeroth order: result = identity, current even power X**0 = identity
      CALL cp_fm_set_all(res_re, 0.0_dp, 1.0_dp)
      CALL cp_fm_set_all(res_im, 0.0_dp, 0.0_dp)
      CALL cp_fm_set_all(even_pow, 0.0_dp, 1.0_dp)

      npairs = CEILING(REAL(ntaylor, dp)/2.0_dp)
      denom = 1.0_dp
      DO k = 1, npairs
         ! odd power X**(2k-1): i**(2k-1) = +i for odd k, -i for even k -> imaginary part
         denom = denom*(REAL(k, dp)*2.0_dp - 1.0_dp)
         CALL parallel_gemm("N", "N", ndim, ndim, ndim, scal, im_matrix, even_pow, 0.0_dp, odd_pow)
         coef = MERGE(-1.0_dp, 1.0_dp, MOD(k, 2) == 0)/denom
         CALL cp_fm_scale_and_add(1.0_dp, res_im, coef, odd_pow)

         ! even power X**(2k): i**(2k) = -1 for odd k, +1 for even k -> real part
         denom = denom*REAL(k, dp)*2.0_dp
         CALL parallel_gemm("N", "N", ndim, ndim, ndim, scal, im_matrix, odd_pow, 0.0_dp, even_pow)
         coef = MERGE(-1.0_dp, 1.0_dp, MOD(k, 2) == 1)/denom
         CALL cp_fm_scale_and_add(1.0_dp, res_re, coef, even_pow)
      END DO

      ! undo the scaling by repeated complex squaring; exp_H doubles as the gemm target
      DO k = 1, nsquare
         CALL cp_complex_fm_gemm("N", "N", ndim, ndim, ndim, 1.0_dp, res_re, res_im, &
                                 res_re, res_im, 0.0_dp, exp_H(1), exp_H(2))
         CALL cp_fm_to_fm(exp_H(1), res_re)
         CALL cp_fm_to_fm(exp_H(2), res_im)
      END DO
      IF (nsquare <= 0) THEN
         CALL cp_fm_to_fm(res_re, exp_H(1))
         CALL cp_fm_to_fm(res_im, exp_H(2))
      END IF

      CALL cp_fm_release(even_pow)
      CALL cp_fm_release(odd_pow)
      CALL cp_fm_release(res_re)
      CALL cp_fm_release(res_im)

      CALL timestop(handle)

   END SUBROUTINE taylor_only_imaginary

! **************************************************************************************************
!> \brief subroutine for general complex matrix exponentials
!>        on input a separate cp_fm_type for real and complex part
!>        on output a size 2 cp_fm_type, first element is the real part of
!>        the exponential second the imaginary
!>        Method: truncated Taylor series of exp(X), X = (re_part + i*im_part)/2**nsquare,
!>        followed by nsquare squarings of the result.
!> \param exp_H   on output exp_H(1) holds the real part, exp_H(2) the imaginary part
!> \param re_part real part of the matrix (not modified)
!> \param im_part imaginary part of the matrix (not modified)
!> \param nsquare number of squaring steps
!> \param ntaylor highest power of X kept in the Taylor series
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE taylor_full_complex(exp_H, re_part, im_part, nsquare, ntaylor)
      TYPE(cp_fm_type), DIMENSION(2)                     :: exp_H
      TYPE(cp_fm_type), INTENT(IN)                       :: re_part, im_part
      INTEGER, INTENT(in)                                :: nsquare, ntaylor

      CHARACTER(len=*), PARAMETER :: routineN = 'taylor_full_complex'

      COMPLEX(KIND=dp)                                   :: coef
      INTEGER                                            :: handle, k, nrow
      REAL(KIND=dp)                                      :: divisor, factorial
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: im_data, re_data
      TYPE(cp_cfm_type)                                  :: acc, next_pow, pow, x_scaled

      CALL timeset(routineN, handle)

      CALL cp_fm_get_info(re_part, local_data=re_data, nrow_global=nrow)
      CALL cp_fm_get_info(im_part, local_data=im_data)
      divisor = 2.0_dp**REAL(nsquare, dp)

      CALL cp_cfm_create(x_scaled, matrix_struct=re_part%matrix_struct, name="T1")
      CALL cp_cfm_create(pow, matrix_struct=re_part%matrix_struct, name="T2")
      CALL cp_cfm_create(next_pow, matrix_struct=re_part%matrix_struct, name="T3")
      CALL cp_cfm_create(acc, matrix_struct=re_part%matrix_struct, name="Tres")

      ! X = (re_part + i*im_part)/2**nsquare, assembled from the local blocks
      x_scaled%local_data = CMPLX(re_data/divisor, im_data/divisor, KIND=dp)

      ! acc starts at the zeroth-order term I; pow holds X**0 = I
      CALL cp_cfm_set_all(acc, z_zero, z_one)
      CALL cp_cfm_set_all(pow, z_zero, z_one)

      factorial = 1.0_dp
      DO k = 1, ntaylor
         factorial = factorial*REAL(k, dp)
         ! next_pow = X**k, added with weight 1/k!
         CALL parallel_gemm("N", "N", nrow, nrow, nrow, z_one, x_scaled, pow, z_zero, next_pow)
         coef = CMPLX(1._dp/factorial, 0.0_dp, kind=dp)
         CALL cp_cfm_scale_and_add(z_one, acc, coef, next_pow)
         CALL cp_cfm_to_cfm(next_pow, pow)
      END DO

      ! exp(2**nsquare * X) = exp(X)**(2**nsquare); pow is reused as scratch
      DO k = 1, nsquare
         CALL parallel_gemm("N", "N", nrow, nrow, nrow, z_one, acc, acc, z_zero, pow)
         CALL cp_cfm_to_cfm(pow, acc)
      END DO

      exp_H(1)%local_data = REAL(acc%local_data, KIND=dp)
      exp_H(2)%local_data = AIMAG(acc%local_data)

      CALL cp_cfm_release(x_scaled)
      CALL cp_cfm_release(pow)
      CALL cp_cfm_release(next_pow)
      CALL cp_cfm_release(acc)
      CALL timestop(handle)

   END SUBROUTINE taylor_full_complex

! **************************************************************************************************
!> \brief optimization function for pade/taylor order and number of squaring steps
!>        For every candidate number of squarings iscale, the scalar exp(norm) is approximated
!>        by R(norm/2**iscale)**(2**iscale), where R is a truncated Taylor series (method == 1)
!>        or a Pade approximant (method == 2). Among the combinations whose relative error is
!>        below eps_exp, the one with the lowest cost estimate is returned.
!> \param norm    norm of the matrix to be exponentiated
!> \param nsquare number of squaring steps (default 12 if nothing cheaper is found)
!> \param norder  Taylor/Pade order (default 12 if nothing cheaper is found)
!> \param eps_exp requested relative accuracy of the scalar test
!> \param method  1: Taylor, 2: Pade; any other value returns the defaults
!> \param do_emd  cost model: iscale + order if .TRUE., iscale + ceiling(order/3) otherwise
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************
   SUBROUTINE get_nsquare_norder(norm, nsquare, norder, eps_exp, method, do_emd)

      REAL(dp), INTENT(in)                               :: norm
      INTEGER, INTENT(out)                               :: nsquare, norder
      REAL(dp), INTENT(in)                               :: eps_exp
      INTEGER, INTENT(in)                                :: method
      LOGICAL, INTENT(in)                                :: do_emd

      INTEGER                                            :: cost, i, iscale, orders(3), p, &
                                                            prev_cost, q
      LOGICAL                                            :: new_scale
      REAL(dp)                                           :: D, eval, exact, N, ref, scaleD, scaleN

      ! orders = [number of squarings, approximation order, Pade numerator order]
      orders(:) = [12, 12, 12]
      ! exact scalar value and error normalisation are loop invariants
      exact = EXP(norm)
      ref = MAX(1.0_dp, exact)

      IF (method == 2) THEN
         DO iscale = 0, 12
            new_scale = .FALSE.
            eval = norm/(2.0_dp**REAL(iscale, dp))
            DO q = 1, 12
               ! numerator degree p is q-1 or q (sub-diagonal or diagonal approximant)
               DO p = MAX(1, q - 1), q
                  D = 1.0_dp
                  N = 1.0_dp
                  DO i = 1, q
                     IF (i <= p) THEN
                        scaleN = fac(p + q - i)*fac(p)/(fac(p + q)*fac(i)*fac(p - i))
                        N = N + scaleN*eval**i
                     END IF
                     scaleD = (-1.0_dp)**i*fac(p + q - i)*fac(q)/(fac(p + q)*fac(i)*fac(q - i))
                     D = D + scaleD*eval**i
                  END DO
                  IF (ABS((exact - (N/D)**(2.0_dp**iscale))/ref) <= eps_exp) THEN
                     IF (do_emd) THEN
                        cost = iscale + q
                        prev_cost = orders(1) + orders(2)
                     ELSE
                        cost = iscale + CEILING(REAL(q, dp)/3.0_dp)
                        prev_cost = orders(1) + CEILING(REAL(orders(2), dp)/3.0_dp)
                     END IF
                     IF (cost < prev_cost) orders(:) = [iscale, q, p]
                     new_scale = .TRUE.
                     EXIT
                  END IF
               END DO
               IF (new_scale) EXIT
            END DO
            ! more squarings cannot beat the current best any more
            IF (iscale >= orders(1) + orders(2) .AND. new_scale) EXIT
         END DO
      ELSE IF (method == 1) THEN
         DO iscale = 0, 6
            new_scale = .FALSE.
            eval = norm/(2.0_dp**REAL(iscale, dp))
            DO p = 1, 20
               N = 1.0_dp
               DO i = 1, p
                  scaleN = 1.0_dp/fac(i)
                  N = N + scaleN*(eval**REAL(i, dp))
               END DO
               IF (ABS((exact - N**(2.0_dp**REAL(iscale, dp)))/ref) <= eps_exp) THEN
                  IF (do_emd) THEN
                     cost = iscale + p
                     prev_cost = orders(1) + orders(2)
                  ELSE
                     cost = iscale + CEILING(REAL(p, dp)/3.0_dp)
                     prev_cost = orders(1) + CEILING(REAL(orders(2), dp)/3.0_dp)
                  END IF
                  IF (cost < prev_cost) orders(:) = [iscale, p, 0]
                  new_scale = .TRUE.
                  EXIT
               END IF
            END DO
            IF (iscale >= orders(1) + orders(2) .AND. new_scale) EXIT
         END DO
      END IF

      nsquare = orders(1)
      norder = orders(2)

   END SUBROUTINE get_nsquare_norder

! **************************************************************************************************
!> \brief exponential of a complex matrix,
!>        calculated using pade approximation together with scaling and squaring
!>        The diagonal [npade/npade] Pade approximant Dpq**(-1)*Npq of exp(X) is evaluated for
!>        X = (re_part + i*im_part)/2**nsquare and the result is squared nsquare times.
!> \param exp_H   on output exp_H(1) holds the real part, exp_H(2) the imaginary part
!> \param re_part real part of the matrix (not modified)
!> \param im_part imaginary part of the matrix (not modified)
!> \param nsquare number of squaring steps
!> \param npade   order of the diagonal Pade approximant
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE exp_pade_full_complex(exp_H, re_part, im_part, nsquare, npade)
      TYPE(cp_fm_type), DIMENSION(2)                     :: exp_H
      TYPE(cp_fm_type), INTENT(IN)                       :: re_part, im_part
      INTEGER, INTENT(in)                                :: nsquare, npade

      CHARACTER(len=*), PARAMETER :: routineN = 'exp_pade_full_complex'

      COMPLEX(KIND=dp)                                   :: scaleD, scaleN
      INTEGER                                            :: handle, i, ldim, ndim, p, q
      REAL(KIND=dp)                                      :: square_fac
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: local_data_im, local_data_re
      TYPE(cp_cfm_type)                                  :: Dpq, fin_p, T1
      TYPE(cp_cfm_type), DIMENSION(2)                    :: mult_p
      TYPE(cp_cfm_type), TARGET                          :: Npq, T2, Tres

      ! diagonal approximant: the coefficient formulas below assume p == q
      p = npade
      q = npade

      CALL timeset(routineN, handle)
      CALL cp_fm_get_info(re_part, local_data=local_data_re, ncol_local=ldim, &
                          nrow_global=ndim)
      CALL cp_fm_get_info(im_part, local_data=local_data_im)

      CALL cp_cfm_create(Dpq, &
                         matrix_struct=re_part%matrix_struct, &
                         name="Dpq")

      square_fac = 2.0_dp**REAL(nsquare, dp)

      CALL cp_cfm_create(T1, &
                         matrix_struct=Dpq%matrix_struct, &
                         name="T1")

      CALL cp_cfm_create(T2, &
                         matrix_struct=T1%matrix_struct, &
                         name="T2")
      CALL cp_cfm_create(Npq, &
                         matrix_struct=T1%matrix_struct, &
                         name="Npq")
      CALL cp_cfm_create(Tres, &
                         matrix_struct=T1%matrix_struct, &
                         name="Tres")

      ! T1 = T2 = X = (re_part + i*im_part)/2**nsquare
      DO i = 1, ldim
         T2%local_data(:, i) = CMPLX(local_data_re(:, i)/square_fac, local_data_im(:, i)/square_fac, KIND=dp)
      END DO
      CALL cp_cfm_to_cfm(T2, T1)
      ! ping-pong buffers for successive powers of X; presumably these derived-type copies
      ! share their storage with T2/Tres through pointer components (verify against cp_cfm_type)
      mult_p(1) = T2
      mult_p(2) = Tres
      CALL cp_cfm_set_all(Npq, z_zero, z_one)
      CALL cp_cfm_set_all(Dpq, z_zero, z_one)

      ! first-order terms: coefficient +1/2 in Npq and -1/2 in Dpq for p == q
      CALL cp_cfm_scale_and_add(z_one, Npq, z_one*0.5_dp, T2)
      CALL cp_cfm_scale_and_add(z_one, Dpq, -z_one*0.5_dp, T2)

      ! orders 2..npade (zero-trip for npade < 2); npade == 2 must include the X**2 term
      DO i = 2, npade
         IF (i <= p) scaleN = CMPLX(fac(p + q - i)*fac(p)/(fac(p + q)*fac(i)*fac(p - i)), 0.0_dp, kind=dp)
         scaleD = CMPLX((-1.0_dp)**i*fac(p + q - i)*fac(q)/(fac(p + q)*fac(i)*fac(q - i)), 0.0_dp, kind=dp)
         ! X**i = X * X**(i-1), alternating between the two mult_p buffers
         CALL parallel_gemm("N", "N", ndim, ndim, ndim, z_one, T1, mult_p(MOD(i, 2) + 1), z_zero, &
                            mult_p(MOD(i + 1, 2) + 1))
         IF (i <= p) CALL cp_cfm_scale_and_add(z_one, Npq, scaleN, mult_p(MOD(i + 1, 2) + 1))
         IF (i <= q) CALL cp_cfm_scale_and_add(z_one, Dpq, scaleD, mult_p(MOD(i + 1, 2) + 1))
      END DO

      ! Npq <- Dpq**(-1) * Npq, the Pade approximation of exp(X)
      CALL cp_cfm_solve(Dpq, Npq)

      ! repeated squaring, ping-ponging between Npq and Tres
      mult_p(2) = Npq
      mult_p(1) = Tres
      IF (nsquare > 0) THEN
         DO i = 1, nsquare
            CALL parallel_gemm("N", "N", ndim, ndim, ndim, z_one, mult_p(MOD(i, 2) + 1), mult_p(MOD(i, 2) + 1), z_zero, &
                               mult_p(MOD(i + 1, 2) + 1))
         END DO
         fin_p = mult_p(MOD(nsquare + 1, 2) + 1)
      ELSE
         fin_p = Npq
      END IF
      DO i = 1, ldim
         exp_H(1)%local_data(:, i) = REAL(fin_p%local_data(:, i), KIND=dp)
         exp_H(2)%local_data(:, i) = AIMAG(fin_p%local_data(:, i))
      END DO

      CALL cp_cfm_release(Npq)
      CALL cp_cfm_release(Dpq)
      CALL cp_cfm_release(T1)
      CALL cp_cfm_release(T2)
      CALL cp_cfm_release(Tres)
      CALL timestop(handle)

   END SUBROUTINE exp_pade_full_complex

! **************************************************************************************************
!> \brief exponential of a purely imaginary matrix i*im_part,
!>        calculated using pade approximation together with scaling and squaring
!>        With X = i*s*A, s = 2**(-nsquare), A = im_part, the powers X**k are real (k even) or
!>        purely imaginary (k odd), so they are built with real gemms and then combined into
!>        the complex Pade numerator Npq and denominator Dpq.
!> \param exp_H   on output exp_H(1) holds the real part, exp_H(2) the imaginary part
!> \param im_part real matrix A of the exponent i*A (not modified)
!> \param nsquare number of squaring steps
!> \param npade   order of the diagonal Pade approximant
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE exp_pade_only_imaginary(exp_H, im_part, nsquare, npade)
      TYPE(cp_fm_type), DIMENSION(2)                     :: exp_H
      TYPE(cp_fm_type), INTENT(IN)                       :: im_part
      INTEGER, INTENT(in)                                :: nsquare, npade

      CHARACTER(len=*), PARAMETER :: routineN = 'exp_pade_only_imaginary'
      REAL(KIND=dp), PARAMETER                           :: rone = 1.0_dp, rzero = 0.0_dp

      COMPLEX(KIND=dp)                                   :: scaleD, scaleN
      INTEGER                                            :: handle, i, j, k, ldim, ndim, p, q
      REAL(KIND=dp)                                      :: my_fac, square_fac
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: local_data_im
      TYPE(cp_cfm_type)                                  :: Dpq, fin_p
      TYPE(cp_cfm_type), DIMENSION(2)                    :: cmult_p
      TYPE(cp_cfm_type), TARGET                          :: Npq, T1
      TYPE(cp_fm_type)                                   :: T2, Tres

      CALL timeset(routineN, handle)
      ! diagonal approximant: the coefficient formulas below assume p == q
      p = npade
      q = npade

      CALL cp_fm_get_info(im_part, local_data=local_data_im, ncol_local=ldim, nrow_global=ndim)
      square_fac = 1.0_dp/(2.0_dp**REAL(nsquare, dp))

      CALL cp_cfm_create(Dpq, &
                         matrix_struct=im_part%matrix_struct, &
                         name="Dpq")

      CALL cp_cfm_create(Npq, &
                         matrix_struct=Dpq%matrix_struct, &
                         name="Npq")

      CALL cp_cfm_create(T1, &
                         matrix_struct=Dpq%matrix_struct, &
                         name="T1")

      CALL cp_fm_create(T2, &
                        matrix_struct=T1%matrix_struct, &
                        name="T2")

      CALL cp_fm_create(Tres, &
                        matrix_struct=T1%matrix_struct, &
                        name="Tres")

      ! T2 holds the current odd power of the scaled matrix s*A. It must start at s*A (not A):
      ! every gemm below multiplies by s*A once more, so an unscaled seed would leave all
      ! terms of order >= 2 off by a factor s whenever nsquare > 0.
      DO i = 1, ldim
         T2%local_data(:, i) = square_fac*local_data_im(:, i)
      END DO

      CALL cp_cfm_set_all(Npq, z_zero, z_one)
      CALL cp_cfm_set_all(Dpq, z_zero, z_one)

      ! first-order terms: +-(1/2)*i*s*A
      DO i = 1, ldim
         Npq%local_data(:, i) = Npq%local_data(:, i) + CMPLX(rzero, 0.5_dp*square_fac*local_data_im(:, i), dp)
         Dpq%local_data(:, i) = Dpq%local_data(:, i) - CMPLX(rzero, 0.5_dp*square_fac*local_data_im(:, i), dp)
      END DO

      ! orders 2..npade, processed as (even, odd) pairs; zero-trip for npade < 2.
      ! npade == 2 must still include the X**2 term.
      DO j = 1, npade/2
         ! even order i = 2j: X**i = (-1)**j * (s*A)**i is real
         i = 2*j
         my_fac = (-rone)**j
         IF (i <= p) scaleN = CMPLX(my_fac*fac(p + q - i)*fac(p)/(fac(p + q)*fac(i)*fac(p - i)), 0.0_dp, dp)
         scaleD = CMPLX(my_fac*fac(p + q - i)*fac(q)/(fac(p + q)*fac(i)*fac(q - i)), 0.0_dp, dp)
         CALL parallel_gemm("N", "N", ndim, ndim, ndim, square_fac, im_part, T2, rzero, Tres)

         DO k = 1, ldim
            Npq%local_data(:, k) = Npq%local_data(:, k) + scaleN*Tres%local_data(:, k)
            Dpq%local_data(:, k) = Dpq%local_data(:, k) + scaleD*Tres%local_data(:, k)
         END DO

         ! odd order i = 2j+1: X**i = i*(-1)**j * (s*A)**i is purely imaginary
         IF (2*j + 1 <= q) THEN
            i = 2*j + 1
            IF (i <= p) scaleN = CMPLX(my_fac*fac(p + q - i)*fac(p)/(fac(p + q)*fac(i)*fac(p - i)), rzero, dp)
            scaleD = CMPLX(-my_fac*fac(p + q - i)*fac(q)/(fac(p + q)*fac(i)*fac(q - i)), rzero, dp)
            CALL parallel_gemm("N", "N", ndim, ndim, ndim, square_fac, im_part, Tres, rzero, T2)

            DO k = 1, ldim
               Npq%local_data(:, k) = Npq%local_data(:, k) + scaleN*CMPLX(rzero, T2%local_data(:, k), dp)
               Dpq%local_data(:, k) = Dpq%local_data(:, k) + scaleD*CMPLX(rzero, T2%local_data(:, k), dp)
            END DO
         END IF
      END DO

      ! Npq <- Dpq**(-1) * Npq, the Pade approximation of exp(X)
      CALL cp_cfm_solve(Dpq, Npq)

      ! repeated squaring, ping-ponging between Npq and T1
      cmult_p(2) = Npq
      cmult_p(1) = T1
      IF (nsquare > 0) THEN
         DO i = 1, nsquare
            CALL parallel_gemm("N", "N", ndim, ndim, ndim, z_one, cmult_p(MOD(i, 2) + 1), cmult_p(MOD(i, 2) + 1), z_zero, &
                               cmult_p(MOD(i + 1, 2) + 1))
         END DO
         fin_p = cmult_p(MOD(nsquare + 1, 2) + 1)
      ELSE
         fin_p = Npq
      END IF

      DO k = 1, ldim
         exp_H(1)%local_data(:, k) = REAL(fin_p%local_data(:, k), KIND=dp)
         exp_H(2)%local_data(:, k) = AIMAG(fin_p%local_data(:, k))
      END DO

      CALL cp_cfm_release(Npq)
      CALL cp_cfm_release(Dpq)
      CALL cp_cfm_release(T1)
      CALL cp_fm_release(T2)
      CALL cp_fm_release(Tres)
      CALL timestop(handle)
   END SUBROUTINE exp_pade_only_imaginary

! **************************************************************************************************
!> \brief exponential of a real matrix,
!>        calculated using pade approximation together with scaling and squaring
!>        The diagonal [npade/npade] Pade approximant Dpq**(-1)*Npq of exp(X), with
!>        X = matrix/2**nsquare, is computed and then squared nsquare times.
!> \param exp_H   on output its local data holds exp(matrix)
!> \param matrix  real matrix to exponentiate (not modified)
!> \param nsquare number of squaring steps
!> \param npade   order of the diagonal Pade approximant
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE exp_pade_real(exp_H, matrix, nsquare, npade)
      TYPE(cp_fm_type), INTENT(IN)                       :: exp_H, matrix
      INTEGER, INTENT(in)                                :: nsquare, npade

      CHARACTER(len=*), PARAMETER                        :: routineN = 'exp_pade_real'
      REAL(KIND=dp), PARAMETER                           :: one = 1.0_dp, zero = 0.0_dp

      INTEGER                                            :: handle, i, j, k, ldim, ndim, p, q
      REAL(KIND=dp)                                      :: my_fac, scaleD, scaleN, square_fac
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: local_data
      TYPE(cp_fm_type)                                   :: Dpq, fin_p
      TYPE(cp_fm_type), DIMENSION(2)                     :: mult_p
      TYPE(cp_fm_type), TARGET                           :: Npq, T1, T2, Tres

      CALL timeset(routineN, handle)
      ! diagonal approximant: the coefficient formulas below assume p == q
      p = npade
      q = npade

      CALL cp_fm_get_info(matrix, local_data=local_data, ncol_local=ldim, nrow_global=ndim)
      square_fac = 2.0_dp**REAL(nsquare, dp)

      CALL cp_fm_create(Dpq, &
                        matrix_struct=matrix%matrix_struct, &
                        name="Dpq")

      CALL cp_fm_create(Npq, &
                        matrix_struct=Dpq%matrix_struct, &
                        name="Npq")

      CALL cp_fm_create(T1, &
                        matrix_struct=Dpq%matrix_struct, &
                        name="T1")

      CALL cp_fm_create(T2, &
                        matrix_struct=T1%matrix_struct, &
                        name="T2")

      CALL cp_fm_create(Tres, &
                        matrix_struct=T1%matrix_struct, &
                        name="Tres")

      ! T1 = T2 = X = matrix/2**nsquare
      DO i = 1, ldim
         T2%local_data(:, i) = local_data(:, i)/square_fac
      END DO

      CALL cp_fm_to_fm(T2, T1)
      CALL cp_fm_set_all(Npq, zero, one)
      CALL cp_fm_set_all(Dpq, zero, one)

      ! first-order terms +-X/2. These must use the scaled X (held in T2), not the unscaled
      ! input, otherwise the approximant is wrong whenever nsquare > 0.
      CALL cp_fm_scale_and_add(one, Npq, 0.5_dp, T2)
      CALL cp_fm_scale_and_add(one, Dpq, -0.5_dp, T2)

      ! ping-pong buffers for successive powers of X
      mult_p(1) = T2
      mult_p(2) = Tres

      ! orders 2..npade (zero-trip for npade < 2)
      DO j = 2, npade
         my_fac = (-1.0_dp)**j
         scaleN = fac(p + q - j)*fac(p)/(fac(p + q)*fac(j)*fac(p - j))
         scaleD = my_fac*fac(p + q - j)*fac(q)/(fac(p + q)*fac(j)*fac(q - j))
         ! X**j = X**(j-1) * X
         CALL parallel_gemm("N", "N", ndim, ndim, ndim, one, mult_p(MOD(j, 2) + 1), T1, &
                            zero, mult_p(MOD(j + 1, 2) + 1))

         DO k = 1, ldim
            Npq%local_data(:, k) = Npq%local_data(:, k) + scaleN*mult_p(MOD(j + 1, 2) + 1)%local_data(:, k)
            Dpq%local_data(:, k) = Dpq%local_data(:, k) + scaleD*mult_p(MOD(j + 1, 2) + 1)%local_data(:, k)
         END DO
      END DO

      ! Npq <- Dpq**(-1) * Npq, the Pade approximation of exp(X)
      CALL cp_fm_solve(Dpq, Npq)

      ! repeated squaring, ping-ponging between Npq and T1
      mult_p(2) = Npq
      mult_p(1) = T1
      IF (nsquare > 0) THEN
         DO i = 1, nsquare
            CALL parallel_gemm("N", "N", ndim, ndim, ndim, one, mult_p(MOD(i, 2) + 1), mult_p(MOD(i, 2) + 1), zero, &
                               mult_p(MOD(i + 1, 2) + 1))
         END DO
         fin_p = mult_p(MOD(nsquare + 1, 2) + 1)
      ELSE
         fin_p = Npq
      END IF

      DO k = 1, ldim
         exp_H%local_data(:, k) = fin_p%local_data(:, k)
      END DO

      CALL cp_fm_release(Npq)
      CALL cp_fm_release(Dpq)
      CALL cp_fm_release(T1)
      CALL cp_fm_release(T2)
      CALL cp_fm_release(Tres)
      CALL timestop(handle)

   END SUBROUTINE exp_pade_real

! ***************************************************************************************************
!> \brief exponential of a complex matrix,
!>        calculated using arnoldi subspace method (directly applies to the MOs)
!> \param mos_old ...
!> \param mos_new ...
!> \param eps_exp ...
!> \param Hre ...
!> \param Him ...
!> \param mos_next ...
!> \param narn_old ...
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE arnoldi(mos_old, mos_new, eps_exp, Hre, Him, mos_next, narn_old)

      TYPE(cp_fm_type), DIMENSION(2)                     :: mos_old, mos_new
      REAL(KIND=dp), INTENT(in)                          :: eps_exp
      TYPE(cp_fm_type), INTENT(IN), OPTIONAL             :: Hre
      TYPE(cp_fm_type), INTENT(IN)                       :: Him
      TYPE(cp_fm_type), DIMENSION(2), OPTIONAL           :: mos_next
      INTEGER, INTENT(inout)                             :: narn_old

      CHARACTER(len=*), PARAMETER                        :: routineN = 'arnoldi'
      REAL(KIND=dp), PARAMETER                           :: rone = 1.0_dp, rzero = 0.0_dp

      INTEGER                                            :: handle, i, icol_local, idim, info, j, l, &
                                                            mydim, nao, narnoldi, ncol_local, &
                                                            newdim, nmo, npade, pade_step
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ipivot
      INTEGER, DIMENSION(:), POINTER                     :: col_indices, col_procs
      LOGICAL                                            :: convergence, double_col, double_row
      REAL(dp), ALLOCATABLE, DIMENSION(:)                :: last_norm, norm1, results
      REAL(dp), ALLOCATABLE, DIMENSION(:, :)             :: D, mat1, mat2, mat3, N
      REAL(dp), ALLOCATABLE, DIMENSION(:, :, :)          :: H_approx, H_approx_save
      REAL(KIND=dp)                                      :: conv_norm, prefac, scaleD, scaleN
      TYPE(cp_fm_struct_type), POINTER                   :: mo_struct, newstruct
      TYPE(cp_fm_type), ALLOCATABLE, DIMENSION(:)        :: V_mats
      TYPE(mp_comm_type)                                 :: col_group
      TYPE(mp_para_env_type), POINTER                    :: para_env

      CALL timeset(routineN, handle)
      para_env => mos_new(1)%matrix_struct%para_env

      CALL cp_fm_get_info(mos_old(1), ncol_local=ncol_local, col_indices=col_indices, &
                          nrow_global=nao, ncol_global=nmo, matrix_struct=mo_struct)
      narnoldi = MIN(18, nao)

      ALLOCATE (results(ncol_local))
      ALLOCATE (norm1(ncol_local))
      ALLOCATE (V_mats(narnoldi + 1))
      ALLOCATE (last_norm(ncol_local))
      ALLOCATE (H_approx(narnoldi, narnoldi, ncol_local))
      ALLOCATE (H_approx_save(narnoldi, narnoldi, ncol_local))
      col_procs => mo_struct%context%blacs2mpi(:, mo_struct%context%mepos(2))
      CALL col_group%from_reordering(para_env, col_procs)

      double_col = .TRUE.
      double_row = .FALSE.
      CALL cp_fm_struct_double(newstruct, mo_struct, mo_struct%context, double_col, double_row)
      H_approx_save = rzero

      DO i = 1, narnoldi + 1
         CALL cp_fm_create(V_mats(i), matrix_struct=newstruct, &
                           name="V_mat"//cp_to_string(i))
      END DO
      CALL cp_fm_get_info(V_mats(1), ncol_global=newdim)

      norm1 = 0.0_dp
!$OMP PARALLEL DO PRIVATE(icol_local) DEFAULT(NONE) SHARED(V_mats,norm1,mos_old,ncol_local)
      DO icol_local = 1, ncol_local
         V_mats(1)%local_data(:, icol_local) = mos_old(1)%local_data(:, icol_local)
         V_mats(1)%local_data(:, icol_local + ncol_local) = mos_old(2)%local_data(:, icol_local)
         norm1(icol_local) = SUM(V_mats(1)%local_data(:, icol_local)**2) &
                             + SUM(V_mats(1)%local_data(:, icol_local + ncol_local)**2)
      END DO

      CALL col_group%sum(norm1)
      !!! normalize the mo vectors
      norm1(:) = SQRT(norm1(:))

!$OMP PARALLEL DO PRIVATE(icol_local) DEFAULT(NONE) SHARED(V_mats,norm1,ncol_local)
      DO icol_local = 1, ncol_local
         V_mats(1)%local_data(:, icol_local) = V_mats(1)%local_data(:, icol_local)/norm1(icol_local)
         V_mats(1)%local_data(:, icol_local + ncol_local) = &
            V_mats(1)%local_data(:, icol_local + ncol_local)/norm1(icol_local)
      END DO

      ! arnoldi subspace procedure to get H_approx
      DO i = 2, narnoldi + 1
         !Be careful, imaginary matrix multiplied with complex. Unfortunately requires a swap of arrays afterwards
         CALL parallel_gemm("N", "N", nao, newdim, nao, 1.0_dp, Him, V_mats(i - 1), 0.0_dp, V_mats(i))

!$OMP PARALLEL DO PRIVATE(icol_local) DEFAULT(NONE) SHARED(mos_new,V_mats,ncol_local,i)
         DO icol_local = 1, ncol_local
            mos_new(1)%local_data(:, icol_local) = V_mats(i)%local_data(:, icol_local)
            V_mats(i)%local_data(:, icol_local) = -V_mats(i)%local_data(:, icol_local + ncol_local)
            V_mats(i)%local_data(:, icol_local + ncol_local) = mos_new(1)%local_data(:, icol_local)
         END DO

         IF (PRESENT(Hre)) THEN
            CALL parallel_gemm("N", "N", nao, newdim, nao, 1.0_dp, Hre, V_mats(i - 1), 1.0_dp, V_mats(i))
         END IF

         DO l = 1, i - 1
!$OMP PARALLEL DO DEFAULT(NONE) SHARED(results,V_mats,ncol_local,l,i)
            DO icol_local = 1, ncol_local
               results(icol_local) = SUM(V_mats(l)%local_data(:, icol_local)*V_mats(i)%local_data(:, icol_local)) + &
                                     SUM(V_mats(l)%local_data(:, icol_local + ncol_local)* &
                                         V_mats(i)%local_data(:, icol_local + ncol_local))
            END DO

            CALL col_group%sum(results)

!$OMP PARALLEL DO DEFAULT(NONE) SHARED(H_approx_save,V_mats,ncol_local,l,i,results)
            DO icol_local = 1, ncol_local
               H_approx_save(l, i - 1, icol_local) = results(icol_local)
               V_mats(i)%local_data(:, icol_local) = V_mats(i)%local_data(:, icol_local) - &
                                                     results(icol_local)*V_mats(l)%local_data(:, icol_local)
               V_mats(i)%local_data(:, icol_local + ncol_local) = &
                  V_mats(i)%local_data(:, icol_local + ncol_local) - &
                  results(icol_local)*V_mats(l)%local_data(:, icol_local + ncol_local)
            END DO
         END DO

!$OMP PARALLEL DO DEFAULT(NONE) SHARED(ncol_local,V_mats,results,i)
         DO icol_local = 1, ncol_local
            results(icol_local) = SUM(V_mats(i)%local_data(:, icol_local)**2) + &
                                  SUM(V_mats(i)%local_data(:, icol_local + ncol_local)**2)
         END DO

         CALL col_group%sum(results)

         IF (i <= narnoldi) THEN

!$OMP PARALLEL DO DEFAULT(NONE) SHARED(H_approx_save,last_norm,V_mats,ncol_local,i,results)
            DO icol_local = 1, ncol_local
               H_approx_save(i, i - 1, icol_local) = SQRT(results(icol_local))
               last_norm(icol_local) = SQRT(results(icol_local))
               V_mats(i)%local_data(:, icol_local) = V_mats(i)%local_data(:, icol_local)/SQRT(results(icol_local))
               V_mats(i)%local_data(:, icol_local + ncol_local) = &
                  V_mats(i)%local_data(:, icol_local + ncol_local)/SQRT(results(icol_local))
            END DO
         ELSE
!$OMP PARALLEL DO DEFAULT(NONE) SHARED(ncol_local,last_norm,results)
            DO icol_local = 1, ncol_local
               last_norm(icol_local) = SQRT(results(icol_local))
            END DO
         END IF

         H_approx(:, :, :) = H_approx_save

         ! PADE approximation for exp(H_approx), everything is done locally

         convergence = .FALSE.
         IF (i >= narn_old) THEN
            npade = 9
            mydim = MIN(i, narnoldi)
            ALLOCATE (ipivot(mydim))
            ALLOCATE (mat1(mydim, mydim))
            ALLOCATE (mat2(mydim, mydim))
            ALLOCATE (mat3(mydim, mydim))
            ALLOCATE (N(mydim, mydim))
            ALLOCATE (D(mydim, mydim))
            DO icol_local = 1, ncol_local
               DO idim = 1, mydim
                  DO j = 1, mydim
                     mat1(idim, j) = H_approx(idim, j, icol_local)/16.0_dp
                     mat3(idim, j) = mat1(idim, j)
                  END DO
               END DO
               N = 0.0_dp
               D = 0.0_dp
               DO idim = 1, mydim
                  N(idim, idim) = rone
                  D(idim, idim) = rone
               END DO
               N(:, :) = N + 0.5_dp*mat1
               D(:, :) = D - 0.5_dp*mat1
               pade_step = 1
               DO idim = 1, 4
                  pade_step = pade_step + 1
                  CALL dgemm("N", 'N', mydim, mydim, mydim, rone, mat1(1, 1), &
                             mydim, mat3(1, 1), mydim, rzero, mat2(1, 1), mydim)
                  scaleN = REAL(fac(2*npade - pade_step)*fac(npade)/ &
                                (fac(2*npade)*fac(pade_step)*fac(npade - pade_step)), dp)
                  scaleD = REAL((-1.0_dp)**pade_step*fac(2*npade - pade_step)*fac(npade)/ &
                                (fac(2*npade)*fac(pade_step)*fac(npade - pade_step)), dp)
                  N(:, :) = N + scaleN*mat2
                  D(:, :) = D + scaleD*mat2
                  pade_step = pade_step + 1
                  CALL dgemm("N", 'N', mydim, mydim, mydim, rone, mat2(1, 1), &
                             mydim, mat1(1, 1), mydim, rzero, mat3(1, 1), mydim)
                  scaleN = REAL(fac(2*npade - pade_step)*fac(npade)/ &
                                (fac(2*npade)*fac(pade_step)*fac(npade - pade_step)), dp)
                  scaleD = REAL((-1.0_dp)**pade_step*fac(2*npade - pade_step)*fac(npade)/ &
                                (fac(2*npade)*fac(pade_step)*fac(npade - pade_step)), dp)
                  N(:, :) = N + scaleN*mat3
                  D(:, :) = D + scaleD*mat3
               END DO

               CALL dgetrf(mydim, mydim, D(1, 1), mydim, ipivot, info)
               CALL dgetrs("N", mydim, mydim, D(1, 1), mydim, ipivot, N, mydim, info)
               CALL dgemm("N", 'N', mydim, mydim, mydim, rone, N(1, 1), mydim, N(1, 1), mydim, rzero, mat1(1, 1), mydim)
               CALL dgemm("N", 'N', mydim, mydim, mydim, rone, mat1(1, 1), mydim, mat1(1, 1), mydim, rzero, N(1, 1), mydim)
               CALL dgemm("N", 'N', mydim, mydim, mydim, rone, N(1, 1), mydim, N(1, 1), mydim, rzero, mat1(1, 1), mydim)
               CALL dgemm("N", 'N', mydim, mydim, mydim, rone, mat1(1, 1), mydim, mat1(1, 1), mydim, rzero, N(1, 1), mydim)
               DO idim = 1, mydim
                  DO j = 1, mydim
                     H_approx(idim, j, icol_local) = N(idim, j)
                  END DO
               END DO
            END DO
            ! H_approx is exp(H_approx) right now, calculate new MOs and check for convergence
            conv_norm = 0.0_dp
            results = 0.0_dp
            DO icol_local = 1, ncol_local
               results(icol_local) = last_norm(icol_local)*H_approx(i - 1, 1, icol_local)
               conv_norm = MAX(conv_norm, ABS(results(icol_local)))
            END DO

            CALL para_env%max(conv_norm)

            IF (conv_norm < eps_exp .OR. i > narnoldi) THEN

               mos_new(1)%local_data = rzero
               mos_new(2)%local_data = rzero
               DO icol_local = 1, ncol_local
                  DO idim = 1, mydim
                     prefac = H_approx(idim, 1, icol_local)*norm1(icol_local)
                     mos_new(1)%local_data(:, icol_local) = mos_new(1)%local_data(:, icol_local) + &
                                                            V_mats(idim)%local_data(:, icol_local)*prefac
                     mos_new(2)%local_data(:, icol_local) = mos_new(2)%local_data(:, icol_local) + &
                                                            V_mats(idim)%local_data(:, icol_local + ncol_local)*prefac
                  END DO
               END DO

               IF (PRESENT(mos_next)) THEN
                  DO icol_local = 1, ncol_local
                     DO idim = 1, mydim
                        DO j = 1, mydim
                           N(idim, j) = H_approx(idim, j, icol_local)
                        END DO
                     END DO
                     CALL dgemm("N", 'N', mydim, mydim, mydim, rone, N(1, 1), mydim, N(1, 1), mydim, rzero, mat1(1, 1), mydim)
                     DO idim = 1, mydim
                        DO j = 1, mydim
                           H_approx(idim, j, icol_local) = mat1(idim, j)
                        END DO
                     END DO
                  END DO
                  mos_next(1)%local_data = rzero
                  mos_next(2)%local_data = rzero
                  DO icol_local = 1, ncol_local
                     DO idim = 1, mydim
                        prefac = H_approx(idim, 1, icol_local)*norm1(icol_local)
                        mos_next(1)%local_data(:, icol_local) = &
                           mos_next(1)%local_data(:, icol_local) + &
                           V_mats(idim)%local_data(:, icol_local)*prefac
                        mos_next(2)%local_data(:, icol_local) = &
                           mos_next(2)%local_data(:, icol_local) + &
                           V_mats(idim)%local_data(:, icol_local + ncol_local)*prefac
                     END DO
                  END DO
               END IF
               IF (conv_norm < eps_exp) THEN
                  convergence = .TRUE.
                  narn_old = i - 1
               END IF
            END IF

            DEALLOCATE (ipivot)
            DEALLOCATE (mat1)
            DEALLOCATE (mat2)
            DEALLOCATE (mat3)
            DEALLOCATE (N)
            DEALLOCATE (D)
         END IF
         IF (convergence) EXIT

      END DO
      CPWARN_IF(.NOT. convergence, "ARNOLDI method did not converge")
      !deallocate all work matrices

      CALL cp_fm_release(V_mats)
      CALL cp_fm_struct_release(newstruct)
      CALL col_group%free()

      DEALLOCATE (H_approx)
      DEALLOCATE (H_approx_save)
      DEALLOCATE (results)
      DEALLOCATE (norm1)
      DEALLOCATE (last_norm)
      CALL timestop(handle)
   END SUBROUTINE arnoldi

END MODULE matrix_exp
